{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "53538f91-97bc-4026-8048-d7742ea3db16",
   "metadata": {},
   "source": [
    "# 基于 Chat Completions API 实现外部函数调用\n",
    "\n",
    "**2023年6月20日，OpenAI 官方在 Chat Completions API 原有的三种不同角色设定（System, Assistant, User）基础上，新增了 Function Calling 功能。**\n",
    "\n",
    "![function_calling_openai_blog](images/function_calling_openai_blog.png)\n",
    "\n",
    "**[详见OpenAI Blog](https://openai.com/blog/function-calling-and-other-api-updates)**"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d8ead58e-35bf-49e5-b21d-4ba2ac299d5a",
   "metadata": {},
   "source": [
    "\n",
    "`functions` 是 Chat Completion API 中的可选参数，用于提供函数定义。其目的是使 GPT 模型能够生成符合所提供定义的函数参数。请注意，API不会实际执行任何函数调用。开发人员需要使用GPT 模型输出来执行函数调用。\n",
    "\n",
    "如果提供了`functions`参数，默认情况下，GPT 模型将决定在何时适当地使用其中一个函数。\n",
    "\n",
    "可以通过将`function_call`参数设置为`{\"name\": \"<insert-function-name>\"}`来强制 API 使用指定函数。\n",
    "\n",
    "同时，也支持通过将`function_call`参数设置为`\"none\"`来强制API不使用任何函数。\n",
    "\n",
    "如果使用了某个函数，则响应中的输出将包含`\"finish_reason\": \"function_call\"`，以及一个具有该函数名称和生成的函数参数的`function_call`对象。\n",
    "\n",
    "\n",
    "![function_calling](images/function_calling.png)\n",
    "\n",
    "\n",
    "## 概述\n",
    "\n",
    "本 Notebook 介绍了如何将 Chat Completions API 与外部函数结合使用，以扩展 GPT 模型的功能。包含以下2个部分：\n",
    "- 如何使用 `functions` 参数\n",
    "- 如何使用 `function_call` 参数\n",
    "- 使用 GPT 模型生成函数和参数\n",
    "- 实际执行 GPT 模型生成的函数（以 SQL 查询为例）\n",
    "\n",
    "### 注意：本示例直接构造 HTTP 请求访问 OpenAI API，因此无需使用 openai Python SDK。"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "64c85e26",
   "metadata": {},
   "source": [
    "## 安装依赖包"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "80e71f33",
   "metadata": {
    "pycharm": {
     "is_executing": true
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Requirement already satisfied: scipy in /root/miniconda3/envs/langchain/lib/python3.10/site-packages (1.11.1)\n",
      "Requirement already satisfied: tenacity in /root/miniconda3/envs/langchain/lib/python3.10/site-packages (8.2.2)\n",
      "Requirement already satisfied: tiktoken in /root/miniconda3/envs/langchain/lib/python3.10/site-packages (0.5.1)\n",
      "Requirement already satisfied: termcolor in /root/miniconda3/envs/langchain/lib/python3.10/site-packages (2.3.0)\n",
      "Requirement already satisfied: openai in /root/miniconda3/envs/langchain/lib/python3.10/site-packages (1.3.4)\n",
      "Requirement already satisfied: requests in /root/miniconda3/envs/langchain/lib/python3.10/site-packages (2.31.0)\n",
      "Requirement already satisfied: numpy<1.28.0,>=1.21.6 in /root/miniconda3/envs/langchain/lib/python3.10/site-packages (from scipy) (1.26.2)\n",
      "Requirement already satisfied: regex>=2022.1.18 in /root/miniconda3/envs/langchain/lib/python3.10/site-packages (from tiktoken) (2023.5.5)\n",
      "Requirement already satisfied: pydantic<3,>=1.9.0 in /root/miniconda3/envs/langchain/lib/python3.10/site-packages (from openai) (1.10.8)\n",
      "Requirement already satisfied: typing-extensions<5,>=4.5 in /root/miniconda3/envs/langchain/lib/python3.10/site-packages (from openai) (4.6.2)\n",
      "Requirement already satisfied: tqdm>4 in /root/miniconda3/envs/langchain/lib/python3.10/site-packages (from openai) (4.65.0)\n",
      "Requirement already satisfied: anyio<4,>=3.5.0 in /root/miniconda3/envs/langchain/lib/python3.10/site-packages (from openai) (3.6.2)\n",
      "Requirement already satisfied: distro<2,>=1.7.0 in /root/miniconda3/envs/langchain/lib/python3.10/site-packages (from openai) (1.8.0)\n",
      "Requirement already satisfied: httpx<1,>=0.23.0 in /root/miniconda3/envs/langchain/lib/python3.10/site-packages (from openai) (0.23.3)\n",
      "Requirement already satisfied: charset-normalizer<4,>=2 in /root/miniconda3/envs/langchain/lib/python3.10/site-packages (from requests) (3.1.0)\n",
      "Requirement already satisfied: certifi>=2017.4.17 in /root/miniconda3/envs/langchain/lib/python3.10/site-packages (from requests) (2023.5.7)\n",
      "Requirement already satisfied: urllib3<3,>=1.21.1 in /root/miniconda3/envs/langchain/lib/python3.10/site-packages (from requests) (1.26.16)\n",
      "Requirement already satisfied: idna<4,>=2.5 in /root/miniconda3/envs/langchain/lib/python3.10/site-packages (from requests) (3.4)\n",
      "Requirement already satisfied: sniffio>=1.1 in /root/miniconda3/envs/langchain/lib/python3.10/site-packages (from anyio<4,>=3.5.0->openai) (1.3.0)\n",
      "Requirement already satisfied: rfc3986[idna2008]<2,>=1.3 in /root/miniconda3/envs/langchain/lib/python3.10/site-packages (from httpx<1,>=0.23.0->openai) (1.5.0)\n",
      "Requirement already satisfied: httpcore<0.17.0,>=0.15.0 in /root/miniconda3/envs/langchain/lib/python3.10/site-packages (from httpx<1,>=0.23.0->openai) (0.16.3)\n",
      "Requirement already satisfied: h11<0.15,>=0.13 in /root/miniconda3/envs/langchain/lib/python3.10/site-packages (from httpcore<0.17.0,>=0.15.0->httpx<1,>=0.23.0->openai) (0.14.0)\n",
      "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
      "\u001b[0m"
     ]
    }
   ],
   "source": [
    "!pip install scipy tenacity tiktoken termcolor openai requests"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "dab872c5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import requests\n",
    "import os\n",
    "from tenacity import retry, wait_random_exponential, stop_after_attempt\n",
    "from termcolor import colored\n",
    "\n",
    "GPT_MODEL = \"gpt-3.5-turbo\""
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "69ee6a93",
   "metadata": {},
   "source": [
    "### 定义工具函数\n",
    "\n",
    "首先，让我们定义一些用于调用聊天完成 API 的实用工具，并维护和跟踪对话状态。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "eb420fcd-1a4d-4cc8-90cf-39a92cf28060",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 使用了retry库，指定在请求失败时的重试策略。\n",
    "# 这里设定的是指数等待（wait_random_exponential），时间间隔的最大值为40秒，并且最多重试3次（stop_after_attempt(3)）。\n",
    "# 定义一个函数chat_completion_request，主要用于发送 聊天补全 请求到OpenAI服务器\n",
    "@retry(wait=wait_random_exponential(multiplier=1, max=40), stop=stop_after_attempt(3))\n",
    "def chat_completion_request(messages, functions=None, function_call=None, model=GPT_MODEL):\n",
    "\n",
    "    # 设定请求的header信息，包括 API_KEY\n",
    "    headers = {\n",
    "        \"Content-Type\": \"application/json\",\n",
    "        \"Authorization\": \"Bearer \" + os.getenv(\"OPENAI_API_KEY\"),\n",
    "    }\n",
    "\n",
    "    # 设定请求的JSON数据，包括GPT 模型名和要进行补全的消息\n",
    "    json_data = {\"model\": model, \"messages\": messages}\n",
    "\n",
    "    # 如果传入了functions，将其加入到json_data中\n",
    "    if functions is not None:\n",
    "        json_data.update({\"functions\": functions})\n",
    "\n",
    "    # 如果传入了function_call，将其加入到json_data中\n",
    "    if function_call is not None:\n",
    "        json_data.update({\"function_call\": function_call})\n",
    "\n",
    "    # 尝试发送POST请求到OpenAI服务器的chat/completions接口\n",
    "    try:\n",
    "        response = requests.post(\n",
    "            \"https://api.openai.com/v1/chat/completions\",\n",
    "            headers=headers,\n",
    "            json=json_data,\n",
    "        )\n",
    "        # 返回服务器的响应\n",
    "        return response\n",
    "\n",
    "    # 如果发送请求或处理响应时出现异常，打印异常信息并返回\n",
    "    except Exception as e:\n",
    "        print(\"Unable to generate ChatCompletion response\")\n",
    "        print(f\"Exception: {e}\")\n",
    "        return e\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "1519c8f7-fde1-44ce-b7cc-fd837905fed4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 定义一个函数pretty_print_conversation，用于打印消息对话内容\n",
    "def pretty_print_conversation(messages):\n",
    "\n",
    "    # 为不同角色设置不同的颜色\n",
    "    role_to_color = {\n",
    "        \"system\": \"red\",\n",
    "        \"user\": \"green\",\n",
    "        \"assistant\": \"blue\",\n",
    "        \"function\": \"magenta\",\n",
    "    }\n",
    "\n",
    "    # 遍历消息列表\n",
    "    for message in messages:\n",
    "\n",
    "        # 如果消息的角色是\"system\"，则用红色打印“content”\n",
    "        if message[\"role\"] == \"system\":\n",
    "            print(colored(f\"system: {message['content']}\\n\", role_to_color[message[\"role\"]]))\n",
    "\n",
    "        # 如果消息的角色是\"user\"，则用绿色打印“content”\n",
    "        elif message[\"role\"] == \"user\":\n",
    "            print(colored(f\"user: {message['content']}\\n\", role_to_color[message[\"role\"]]))\n",
    "\n",
    "        # 如果消息的角色是\"assistant\"，并且消息中包含\"function_call\"，则用蓝色打印\"function_call\"\n",
    "        elif message[\"role\"] == \"assistant\" and message.get(\"function_call\"):\n",
    "            print(colored(f\"assistant[function_call]: {message['function_call']}\\n\", role_to_color[message[\"role\"]]))\n",
    "\n",
    "        # 如果消息的角色是\"assistant\"，但是消息中不包含\"function_call\"，则用蓝色打印“content”\n",
    "        elif message[\"role\"] == \"assistant\" and not message.get(\"function_call\"):\n",
    "            print(colored(f\"assistant[content]: {message['content']}\\n\", role_to_color[message[\"role\"]]))\n",
    "\n",
    "        # 如果消息的角色是\"function\"，则用品红色打印“function”\n",
    "        elif message[\"role\"] == \"function\":\n",
    "            print(colored(f\"function ({message['name']}): {message['content']}\\n\", role_to_color[message[\"role\"]]))\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "29d4e02b",
   "metadata": {},
   "source": [
    "### 如何使用 functions 参数\n",
    "\n",
    "这段代码定义了两个可以在程序中调用的函数，分别是获取当前天气和获取未来N天的天气预报。\n",
    "\n",
    "每个函数(function)都有其名称、描述和需要的参数（包括参数的类型、描述等信息）。\n",
    "\n",
    "![functions_param](images/functions_param.png)\n",
    "\n",
    "我们将把这些传递给 Chat Completions API，以生成符合规范的函数。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "9f172ed2-d0a8-4aaf-9ed3-a3cae53ed4e5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 定义一个名为functions的列表，其中包含两个字典，这两个字典分别定义了两个功能的相关参数\n",
    "\n",
    "# 第一个字典定义了一个名为\"get_current_weather\"的功能\n",
    "functions = [\n",
    "    {\n",
    "        \"name\": \"get_current_weather\",  # 功能的名称\n",
    "        \"description\": \"Get the current weather\",  # 功能的描述\n",
    "        \"parameters\": {  # 定义该功能需要的参数\n",
    "            \"type\": \"object\",\n",
    "            \"properties\": {  # 参数的属性\n",
    "                \"location\": {  # 地点参数\n",
    "                    \"type\": \"string\",  # 参数类型为字符串\n",
    "                    \"description\": \"The city and state, e.g. San Francisco, CA\",  # 参数的描述\n",
    "                },\n",
    "                \"format\": {  # 温度单位参数\n",
    "                    \"type\": \"string\",  # 参数类型为字符串\n",
    "                    \"enum\": [\"celsius\", \"fahrenheit\"],  # 参数的取值范围\n",
    "                    \"description\": \"The temperature unit to use. Infer this from the users location.\",  # 参数的描述\n",
    "                },\n",
    "            },\n",
    "            \"required\": [\"location\", \"format\"],  # 该功能需要的必要参数\n",
    "        },\n",
    "    },\n",
    "    # 第二个字典定义了一个名为\"get_n_day_weather_forecast\"的功能\n",
    "    {\n",
    "        \"name\": \"get_n_day_weather_forecast\",  # 功能的名称\n",
    "        \"description\": \"Get an N-day weather forecast\",  # 功能的描述\n",
    "        \"parameters\": {  # 定义该功能需要的参数\n",
    "            \"type\": \"object\",\n",
    "            \"properties\": {  # 参数的属性\n",
    "                \"location\": {  # 地点参数\n",
    "                    \"type\": \"string\",  # 参数类型为字符串\n",
    "                    \"description\": \"The city and state, e.g. San Francisco, CA\",  # 参数的描述\n",
    "                },\n",
    "                \"format\": {  # 温度单位参数\n",
    "                    \"type\": \"string\",  # 参数类型为字符串\n",
    "                    \"enum\": [\"celsius\", \"fahrenheit\"],  # 参数的取值范围\n",
    "                    \"description\": \"The temperature unit to use. Infer this from the users location.\",  # 参数的描述\n",
    "                },\n",
    "                \"num_days\": {  # 预测天数参数\n",
    "                    \"type\": \"integer\",  # 参数类型为整数\n",
    "                    \"description\": \"The number of days to forecast\",  # 参数的描述\n",
    "                }\n",
    "            },\n",
    "            \"required\": [\"location\", \"format\", \"num_days\"]  # 该功能需要的必要参数\n",
    "        },\n",
    "    },\n",
    "]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d2e25069",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "bfc39899",
   "metadata": {},
   "source": [
    "这段代码首先定义了一个`messages`列表用来存储聊天的消息，然后向列表中添加了系统和用户的消息。\n",
    "\n",
    "然后，它使用了之前定义的`chat_completion_request`函数发送一个请求，传入的参数包括消息列表和函数列表。\n",
    "\n",
    "在接收到响应后，它从JSON响应中解析出助手的消息，并将其添加到消息列表中。\n",
    "\n",
    "最后，它打印出 GPT 模型回复的消息。\n",
    "\n",
    "**（如果我们询问当前天气，GPT 模型会回复让你给出更准确的问题。）**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "518d6827",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[31msystem: Don't make assumptions about what values to plug into functions. Ask for clarification if a user request is ambiguous.\n",
      "\u001b[0m\n",
      "\u001b[32muser: What's the weather like today\n",
      "\u001b[0m\n",
      "\u001b[34massistant[content]: Sure, can you please provide me with your location?\n",
      "\u001b[0m\n"
     ]
    }
   ],
   "source": [
    "# 定义一个空列表messages，用于存储聊天的内容\n",
    "messages = []\n",
    "\n",
    "# 使用append方法向messages列表添加一条系统角色的消息\n",
    "messages.append({\n",
    "    \"role\": \"system\",  # 消息的角色是\"system\"\n",
    "    \"content\": \"Don't make assumptions about what values to plug into functions. Ask for clarification if a user request is ambiguous.\"  # 消息的内容\n",
    "})\n",
    "\n",
    "# 向messages列表添加一条用户角色的消息\n",
    "messages.append({\n",
    "    \"role\": \"user\",  # 消息的角色是\"user\"\n",
    "    \"content\": \"What's the weather like today\"  # 用户询问今天的天气情况\n",
    "})\n",
    "\n",
    "# 使用定义的chat_completion_request函数发起一个请求，传入messages和functions作为参数\n",
    "chat_response = chat_completion_request(\n",
    "    messages, functions=functions\n",
    ")\n",
    "\n",
    "# 解析返回的JSON数据，获取助手的回复消息\n",
    "assistant_message = chat_response.json()[\"choices\"][0][\"message\"]\n",
    "\n",
    "# 将助手的回复消息添加到messages列表中\n",
    "messages.append(assistant_message)\n",
    "\n",
    "pretty_print_conversation(messages)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ec40cddb-0746-4caf-bb32-46b3353ef089",
   "metadata": {},
   "source": [
    "**(我们需要提供更详细的信息，以便于 GPT 模型为我们生成适当的函数和对应参数。)**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "7110fdb9-d791-4e7e-8d1e-8786bf9f6d82",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "dict"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "type(assistant_message)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3b7114f4-ca38-445f-afef-134fccd79039",
   "metadata": {},
   "source": [
    "## 使用 GPT 模型生成函数和对应参数"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "4c999375",
   "metadata": {},
   "source": [
    "下面这段代码先向messages列表中添加了用户的位置信息。\n",
    "\n",
    "然后再次使用了chat_completion_request函数发起请求，只是这次传入的消息列表已经包括了用户的新消息。\n",
    "\n",
    "在获取到响应后，它同样从JSON响应中解析出助手的消息，并将其添加到消息列表中。\n",
    "\n",
    "最后，打印出助手的新的回复消息。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "59d15e8d-f4ab-48c8-aadc-1ad994d8e3da",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[31msystem: Don't make assumptions about what values to plug into functions. Ask for clarification if a user request is ambiguous.\n",
      "\u001b[0m\n",
      "\u001b[32muser: What's the weather like today\n",
      "\u001b[0m\n",
      "\u001b[34massistant[content]: Sure, can you please provide me with your location?\n",
      "\u001b[0m\n",
      "\u001b[32muser: I'm in Shanghai, China.\n",
      "\u001b[0m\n",
      "\u001b[34massistant[function_call]: {'name': 'get_current_weather', 'arguments': '{\\n  \"location\": \"Shanghai, China\",\\n  \"format\": \"celsius\"\\n}'}\n",
      "\u001b[0m\n"
     ]
    }
   ],
   "source": [
    "# 向messages列表添加一条用户角色的消息，用户告知他们在苏格兰的格拉斯哥\n",
    "messages.append({\n",
    "    \"role\": \"user\",  # 消息的角色是\"user\"\n",
    "    \"content\": \"I'm in Shanghai, China.\"  # 用户的消息内容\n",
    "})\n",
    "\n",
    "# 再次使用定义的chat_completion_request函数发起一个请求，传入更新后的messages和functions作为参数\n",
    "chat_response = chat_completion_request(\n",
    "    messages, functions=functions\n",
    ")\n",
    "\n",
    "# 解析返回的JSON数据，获取助手的新的回复消息\n",
    "assistant_message = chat_response.json()[\"choices\"][0][\"message\"]\n",
    "\n",
    "# 将助手的新的回复消息添加到messages列表中\n",
    "messages.append(assistant_message)\n",
    "\n",
    "pretty_print_conversation(messages)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "c14d4762",
   "metadata": {},
   "source": [
    "这段代码的逻辑大体与上一段代码相同，区别在于这次用户的询问中涉及到未来若干天（x天）的天气预报。\n",
    "\n",
    "在获取到回复后，它同样从JSON响应中解析出助手的消息，并将其添加到消息列表中。\n",
    "\n",
    "然后打印出助手的回复消息。\n",
    "\n",
    "**（通过不同的prompt方式，我们可以让它针对我们告诉它的其他功能。）**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "fa232e54",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[31msystem: Don't make assumptions about what values to plug into functions. Ask for clarification if a user request is ambiguous.\n",
      "\u001b[0m\n",
      "\u001b[32muser: what is the weather going to be like in Shanghai, China over the next x days\n",
      "\u001b[0m\n",
      "\u001b[34massistant[content]: Sure, I can help you with that. Please provide the number of days you would like the weather forecast for.\n",
      "\u001b[0m\n"
     ]
    }
   ],
   "source": [
    "# 初始化一个空的messages列表\n",
    "messages = []\n",
    "\n",
    "# 向messages列表添加一条系统角色的消息，要求不做关于函数参数值的假设，如果用户的请求模糊，应该寻求澄清\n",
    "messages.append({\n",
    "    \"role\": \"system\",  # 消息的角色是\"system\"\n",
    "    \"content\": \"Don't make assumptions about what values to plug into functions. Ask for clarification if a user request is ambiguous.\"\n",
    "})\n",
    "\n",
    "# 向messages列表添加一条用户角色的消息，用户询问在未来x天内苏格兰格拉斯哥的天气情况\n",
    "messages.append({\n",
    "    \"role\": \"user\",  # 消息的角色是\"user\"\n",
    "    \"content\": \"what is the weather going to be like in Shanghai, China over the next x days\"\n",
    "})\n",
    "\n",
    "# 使用定义的chat_completion_request函数发起一个请求，传入messages和functions作为参数\n",
    "chat_response = chat_completion_request(\n",
    "    messages, functions=functions\n",
    ")\n",
    "\n",
    "# 解析返回的JSON数据，获取助手的回复消息\n",
    "assistant_message = chat_response.json()[\"choices\"][0][\"message\"]\n",
    "\n",
    "# 将助手的回复消息添加到messages列表中\n",
    "messages.append(assistant_message)\n",
    "\n",
    "# 打印助手的回复消息\n",
    "pretty_print_conversation(messages)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "6172ddac",
   "metadata": {},
   "source": [
    "**(GPT 模型再次要求我们澄清，因为它还没有足够的信息。在这种情况下，它已经知道预测的位置，但需要知道需要多少天的预测。)**\n",
    "\n",
    "这段代码的主要目标是将用户指定的天数（5天）添加到消息列表中，然后再次调用chat_completion_request函数发起一个请求。\n",
    "\n",
    "返回的响应中包含了助手对用户的回复，即未来5天的天气预报。\n",
    "\n",
    "这个预报是基于用户指定的地点（上海）和天数（5天）生成的。\n",
    "\n",
    "在代码的最后，它解析出返回的JSON响应中的第一个选项，这就是助手的回复消息。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "c7d8a543",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[31msystem: Don't make assumptions about what values to plug into functions. Ask for clarification if a user request is ambiguous.\n",
      "\u001b[0m\n",
      "\u001b[32muser: what is the weather going to be like in Shanghai, China over the next x days\n",
      "\u001b[0m\n",
      "\u001b[34massistant[content]: Sure, I can help you with that. Please provide the number of days you would like the weather forecast for.\n",
      "\u001b[0m\n",
      "\u001b[32muser: 5 days\n",
      "\u001b[0m\n",
      "\u001b[34massistant[function_call]: {'name': 'get_n_day_weather_forecast', 'arguments': '{\\n  \"location\": \"Shanghai, China\",\\n  \"format\": \"celsius\",\\n  \"num_days\": 5\\n}'}\n",
      "\u001b[0m\n"
     ]
    }
   ],
   "source": [
    "# 向messages列表添加一条用户角色的消息，用户指定接下来的天数为5天\n",
    "messages.append({\n",
    "    \"role\": \"user\",  # 消息的角色是\"user\"\n",
    "    \"content\": \"5 days\"\n",
    "})\n",
    "\n",
    "# 使用定义的chat_completion_request函数发起一个请求，传入messages和functions作为参数\n",
    "chat_response = chat_completion_request(\n",
    "    messages, functions=functions\n",
    ")\n",
    "\n",
    "# 解析返回的JSON数据，获取第一个选项\n",
    "assistant_message = chat_response.json()[\"choices\"][0][\"message\"]\n",
    "\n",
    "# 将助手的回复消息添加到messages列表中\n",
    "messages.append(assistant_message)\n",
    "\n",
    "# 打印助手的回复消息\n",
    "pretty_print_conversation(messages)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "4b758a0a",
   "metadata": {},
   "source": [
    "#### 强制使用指定函数"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "412f79ba",
   "metadata": {},
   "source": [
    "我们可以通过使用`function_call`参数来强制GPT 模型使用指定函数，例如`get_n_day_weather_forecast`。\n",
    "\n",
    "通过这种方式，可以让 GPT 模型学习如何使用该函数。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "559371b7",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[31msystem: Don't make assumptions about what values to plug into functions. Ask for clarification if a user request is ambiguous.\n",
      "\u001b[0m\n",
      "\u001b[32muser: Give me a weather report for San Diego, USA.\n",
      "\u001b[0m\n",
      "\u001b[34massistant[function_call]: {'name': 'get_n_day_weather_forecast', 'arguments': '{\\n  \"location\": \"San Diego, USA\",\\n  \"format\": \"celsius\",\\n  \"num_days\": 1\\n}'}\n",
      "\u001b[0m\n"
     ]
    }
   ],
   "source": [
    "# 在这个代码单元中，我们强制GPT 模型使用get_n_day_weather_forecast函数\n",
    "messages = []  # 创建一个空的消息列表\n",
    "\n",
    "# 添加系统角色的消息\n",
    "messages.append({\n",
    "    \"role\": \"system\",  # 角色为系统\n",
    "    \"content\": \"Don't make assumptions about what values to plug into functions. Ask for clarification if a user request is ambiguous.\"\n",
    "})\n",
    "\n",
    "# 添加用户角色的消息\n",
    "messages.append({\n",
    "    \"role\": \"user\",  # 角色为用户\n",
    "    \"content\": \"Give me a weather report for San Diego, USA.\"\n",
    "})\n",
    "\n",
    "# 使用定义的chat_completion_request函数发起一个请求，传入messages、functions以及特定的function_call作为参数\n",
    "chat_response = chat_completion_request(\n",
    "    messages, functions=functions, function_call={\"name\": \"get_n_day_weather_forecast\"}\n",
    ")\n",
    "\n",
    "# 解析返回的JSON数据，获取第一个选项\n",
    "assistant_message = chat_response.json()[\"choices\"][0][\"message\"]\n",
    "\n",
    "# 将助手的回复消息添加到messages列表中\n",
    "messages.append(assistant_message)\n",
    "\n",
    "# 打印助手的回复消息\n",
    "pretty_print_conversation(messages)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f27e822c-5e19-4bff-a44c-6e0d57fbc63e",
   "metadata": {},
   "source": [
    "下面这段代码演示了在不强制使用特定函数（`get_n_day_weather_forecast`）的情况下，GPT 模型可能会选择不同的方式来回应用户的请求。对于给定的用户请求\"Give me a weather report for San Diego, USA.\"，GPT 模型可能不会调用`get_n_day_weather_forecast`函数。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "3c9b9446-3193-4332-98e6-17f155bd5824",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[31msystem: Don't make assumptions about what values to plug into functions. Ask for clarification if a user request is ambiguous.\n",
      "\u001b[0m\n",
      "\u001b[32muser: Give me a weather report for San Diego, USA.\n",
      "\u001b[0m\n",
      "\u001b[34massistant[function_call]: {'name': 'get_current_weather', 'arguments': '{\\n  \"location\": \"San Diego, USA\",\\n  \"format\": \"celsius\"\\n}'}\n",
      "\u001b[0m\n"
     ]
    }
   ],
   "source": [
    "# 如果我们不强制GPT 模型使用 get_n_day_weather_forecast，它可能不会使用\n",
    "messages = []  # 创建一个空的消息列表\n",
    "\n",
    "# 添加系统角色的消息\n",
    "messages.append({\n",
    "    \"role\": \"system\",  # 角色为系统\n",
    "    \"content\": \"Don't make assumptions about what values to plug into functions. Ask for clarification if a user request is ambiguous.\"\n",
    "})\n",
    "\n",
    "# 添加用户角色的消息\n",
    "messages.append({\n",
    "    \"role\": \"user\",  # 角色为用户\n",
    "    \"content\": \"Give me a weather report for San Diego, USA.\"\n",
    "})\n",
    "\n",
    "# 使用定义的chat_completion_request函数发起一个请求，传入messages和functions作为参数\n",
    "chat_response = chat_completion_request(\n",
    "    messages, functions=functions\n",
    ")\n",
    "\n",
    "# 解析返回的JSON数据，获取第一个选项\n",
    "assistant_message = chat_response.json()[\"choices\"][0][\"message\"]\n",
    "\n",
    "# 将助手的回复消息添加到messages列表中\n",
    "messages.append(assistant_message)\n",
    "\n",
    "# 打印助手的回复消息\n",
    "pretty_print_conversation(messages)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "3bd70e48",
   "metadata": {},
   "source": [
    "#### 强制不使用函数\n",
    "\n",
    "然后，我们创建另一个消息列表，并添加系统和用户的消息。这次用户请求的是加拿大多伦多当前的天气（使用摄氏度）。\n",
    "\n",
    "随后，代码再次调用`chat_completion_request`函数，\n",
    "\n",
    "但这次在`function_call`参数中明确指定了\"none\"，表示GPT 模型在处理此请求时不能调用任何函数。\n",
    "\n",
    "最后，代码解析返回的JSON响应，获取第一个选项的消息，即 GPT 模型的回应。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "aa95c63b-ee7a-41dc-b1a3-6173edb9f74e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[31msystem: Don't make assumptions about what values to plug into functions. Ask for clarification if a user request is ambiguous.\n",
      "\u001b[0m\n",
      "\u001b[32muser: Give me the current weather (use Celcius) for Toronto, Canada.\n",
      "\u001b[0m\n",
      "\u001b[34massistant[content]: {\n",
      "  \"location\": \"Toronto, Canada\",\n",
      "  \"format\": \"celsius\"\n",
      "}\n",
      "\u001b[0m\n"
     ]
    }
   ],
   "source": [
    "# 创建另一个空的消息列表\n",
    "messages = []\n",
    "\n",
    "# 添加系统角色的消息\n",
    "messages.append({\n",
    "    \"role\": \"system\",  # 角色为系统\n",
    "    \"content\": \"Don't make assumptions about what values to plug into functions. Ask for clarification if a user request is ambiguous.\"\n",
    "})\n",
    "\n",
    "# 添加用户角色的消息\n",
    "messages.append({\n",
    "    \"role\": \"user\",  # 角色为用户\n",
    "    \"content\": \"Give me the current weather (use Celcius) for Toronto, Canada.\"\n",
    "})\n",
    "\n",
    "# 使用定义的chat_completion_request函数发起一个请求，传入messages、functions和function_call作为参数\n",
    "chat_response = chat_completion_request(\n",
    "    messages, functions=functions, function_call=\"none\"\n",
    ")\n",
    "\n",
    "# 解析返回的JSON数据，获取第一个选项\n",
    "assistant_message = chat_response.json()[\"choices\"][0][\"message\"]\n",
    "\n",
    "# 将助手的回复消息添加到messages列表中\n",
    "messages.append(assistant_message)\n",
    "\n",
    "# 打印助手的回复消息\n",
    "pretty_print_conversation(messages)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "b4482aee",
   "metadata": {},
   "source": [
    "## 执行 GPT 模型生成的函数\n",
    "\n",
    "接着，我们将演示如何执行输入为 GPT 模型生成的函数，并利用这一点来实现一个可以帮助我们回答关于数据库的问题的代理。\n",
    "\n",
    "为了简单起见，我们将使用[Chinook样本数据库](https://www.sqlitetutorial.net/sqlite-sample-database/)。\n",
    "\n",
    "![chinook_db](images/chinook_db.jpeg)\n",
    "\n",
    "*注意：* 在生产环境中，SQL生成可能存在较高风险，因为GPT 模型在生成正确的SQL方面并不完全可靠。"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "f7654fef",
   "metadata": {},
   "source": [
    "### 定义一个执行SQL查询的函数\n",
    "\n",
    "首先，让我们定义一些有用的实用函数来从SQLite数据库中提取数据。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "30f6b60e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Opened database successfully\n"
     ]
    }
   ],
   "source": [
    "import sqlite3\n",
    "\n",
    "conn = sqlite3.connect(\"data/chinook.db\")\n",
    "print(\"Opened database successfully\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "61ebf7cc-5d2a-463e-804c-849bf7e276c4",
   "metadata": {},
   "source": [
    "![chinook](images/chinook.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b2b315f2-6c67-4548-88ac-673f840926a8",
   "metadata": {},
   "source": [
    "首先定义三个函数`get_table_names`、`get_column_names`和`get_database_info`，用于从数据库连接对象中获取数据库的表名、表的列名以及整体数据库的信息。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "abec0214",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_table_names(conn):\n",
    "    \"\"\"返回一个包含所有表名的列表\"\"\"\n",
    "    table_names = []  # 创建一个空的表名列表\n",
    "    # 执行SQL查询，获取数据库中所有表的名字\n",
    "    tables = conn.execute(\"SELECT name FROM sqlite_master WHERE type='table';\")\n",
    "    # 遍历查询结果，并将每个表名添加到列表中\n",
    "    for table in tables.fetchall():\n",
    "        table_names.append(table[0])\n",
    "    return table_names  # 返回表名列表\n",
    "\n",
    "\n",
    "def get_column_names(conn, table_name):\n",
    "    \"\"\"返回一个给定表的所有列名的列表\"\"\"\n",
    "    column_names = []  # 创建一个空的列名列表\n",
    "    # 执行SQL查询，获取表的所有列的信息\n",
    "    columns = conn.execute(f\"PRAGMA table_info('{table_name}');\").fetchall()\n",
    "    # 遍历查询结果，并将每个列名添加到列表中\n",
    "    for col in columns:\n",
    "        column_names.append(col[1])\n",
    "    return column_names  # 返回列名列表\n",
    "\n",
    "\n",
    "def get_database_info(conn):\n",
    "    \"\"\"返回一个字典列表，每个字典包含一个表的名字和列信息\"\"\"\n",
    "    table_dicts = []  # 创建一个空的字典列表\n",
    "    # 遍历数据库中的所有表\n",
    "    for table_name in get_table_names(conn):\n",
    "        columns_names = get_column_names(conn, table_name)  # 获取当前表的所有列名\n",
    "        # 将表名和列名信息作为一个字典添加到列表中\n",
    "        table_dicts.append({\"table_name\": table_name, \"column_names\": columns_names})\n",
    "    return table_dicts  # 返回字典列表"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "77e6e5ea",
   "metadata": {},
   "source": [
    "将数据库信息转换为 Python 字典类型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "0c0104cd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 获取数据库信息，并存储为字典列表\n",
    "database_schema_dict = get_database_info(conn)\n",
    "\n",
    "# 将数据库信息转换为字符串格式，方便后续使用\n",
    "database_schema_string = \"\\n\".join(\n",
    "    [\n",
    "        f\"Table: {table['table_name']}\\nColumns: {', '.join(table['column_names'])}\"\n",
    "        for table in database_schema_dict\n",
    "    ]\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "2daf9a2d-c4e0-4f56-9d01-35b8df25753e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[{'table_name': 'albums', 'column_names': ['AlbumId', 'Title', 'ArtistId']},\n",
       " {'table_name': 'sqlite_sequence', 'column_names': ['name', 'seq']},\n",
       " {'table_name': 'artists', 'column_names': ['ArtistId', 'Name']},\n",
       " {'table_name': 'customers',\n",
       "  'column_names': ['CustomerId',\n",
       "   'FirstName',\n",
       "   'LastName',\n",
       "   'Company',\n",
       "   'Address',\n",
       "   'City',\n",
       "   'State',\n",
       "   'Country',\n",
       "   'PostalCode',\n",
       "   'Phone',\n",
       "   'Fax',\n",
       "   'Email',\n",
       "   'SupportRepId']},\n",
       " {'table_name': 'employees',\n",
       "  'column_names': ['EmployeeId',\n",
       "   'LastName',\n",
       "   'FirstName',\n",
       "   'Title',\n",
       "   'ReportsTo',\n",
       "   'BirthDate',\n",
       "   'HireDate',\n",
       "   'Address',\n",
       "   'City',\n",
       "   'State',\n",
       "   'Country',\n",
       "   'PostalCode',\n",
       "   'Phone',\n",
       "   'Fax',\n",
       "   'Email']},\n",
       " {'table_name': 'genres', 'column_names': ['GenreId', 'Name']},\n",
       " {'table_name': 'invoices',\n",
       "  'column_names': ['InvoiceId',\n",
       "   'CustomerId',\n",
       "   'InvoiceDate',\n",
       "   'BillingAddress',\n",
       "   'BillingCity',\n",
       "   'BillingState',\n",
       "   'BillingCountry',\n",
       "   'BillingPostalCode',\n",
       "   'Total']},\n",
       " {'table_name': 'invoice_items',\n",
       "  'column_names': ['InvoiceLineId',\n",
       "   'InvoiceId',\n",
       "   'TrackId',\n",
       "   'UnitPrice',\n",
       "   'Quantity']},\n",
       " {'table_name': 'media_types', 'column_names': ['MediaTypeId', 'Name']},\n",
       " {'table_name': 'playlists', 'column_names': ['PlaylistId', 'Name']},\n",
       " {'table_name': 'playlist_track', 'column_names': ['PlaylistId', 'TrackId']},\n",
       " {'table_name': 'tracks',\n",
       "  'column_names': ['TrackId',\n",
       "   'Name',\n",
       "   'AlbumId',\n",
       "   'MediaTypeId',\n",
       "   'GenreId',\n",
       "   'Composer',\n",
       "   'Milliseconds',\n",
       "   'Bytes',\n",
       "   'UnitPrice']},\n",
       " {'table_name': 'sqlite_stat1', 'column_names': ['tbl', 'idx', 'stat']}]"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "database_schema_dict"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "ae73c9ee",
   "metadata": {},
   "source": [
    "然后，定义一个函数`ask_database`。\n",
    "\n",
    "目标是让 GPT 模型帮我们构造一个完整的 SQL 查询。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "0258813a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 定义一个功能列表，其中包含一个功能字典，该字典定义了一个名为\"ask_database\"的功能，用于回答用户关于音乐的问题\n",
    "functions = [\n",
    "    {\n",
    "        \"name\": \"ask_database\",\n",
    "        \"description\": \"Use this function to answer user questions about music. Output should be a fully formed SQL query.\",\n",
    "        \"parameters\": {\n",
    "            \"type\": \"object\",\n",
    "            \"properties\": {\n",
    "                \"query\": {\n",
    "                    \"type\": \"string\",\n",
    "                    \"description\": f\"\"\"\n",
    "                            SQL query extracting info to answer the user's question.\n",
    "                            SQL should be written using this database schema:\n",
    "                            {database_schema_string}\n",
    "                            The query should be returned in plain text, not in JSON.\n",
    "                            \"\"\",\n",
    "                }\n",
    "            },\n",
    "            \"required\": [\"query\"],\n",
    "        },\n",
    "    }\n",
    "]\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "da08c121",
   "metadata": {},
   "source": [
    "### 执行 SQL 查询\n",
    "\n",
    "首先，定义两个函数`ask_database`和`execute_function_call`\n",
    "- 前者用于实际执行 SQL 查询并返回结果\n",
    "- 后者用于根据消息中的功能调用信息来执行相应的功能并获取结果\n",
    "\n",
    "然后，创建一个消息列表，并向其中添加了一个系统消息和一个用户消息。系统消息的内容是指示对话的目标，用户消息的内容是用户的问题。\n",
    "\n",
    "接着，使用`chat_completion_request`函数发出聊天请求并获取响应，然后从响应中提取出助手的消息并添加到消息列表中。\n",
    "\n",
    "如果助手的消息中包含功能调用，那么就使用`execute_function_call`函数执行这个功能调用并获取结果，然后将结果作为一个功能消息添加到消息列表中。\n",
    "\n",
    "最后，使用`pretty_print_conversation`函数打印出整个对话。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "3584209d-f2b3-4ece-ab23-78636cbcfa50",
   "metadata": {},
   "outputs": [],
   "source": [
    "def ask_database(conn, query):\n",
    "    \"\"\"使用 query 来查询 SQLite 数据库的函数。\"\"\"\n",
    "    try:\n",
    "        results = str(conn.execute(query).fetchall())  # 执行查询，并将结果转换为字符串\n",
    "    except Exception as e:  # 如果查询失败，捕获异常并返回错误信息\n",
    "        results = f\"query failed with error: {e}\"\n",
    "    return results  # 返回查询结果\n",
    "\n",
    "\n",
    "def execute_function_call(message):\n",
    "    \"\"\"执行函数调用\"\"\"\n",
    "    # 判断功能调用的名称是否为 \"ask_database\"\n",
    "    if message[\"function_call\"][\"name\"] == \"ask_database\":\n",
    "        # 如果是，则获取功能调用的参数，这里是 SQL 查询\n",
    "        query = json.loads(message[\"function_call\"][\"arguments\"])[\"query\"]\n",
    "        # 使用 ask_database 函数执行查询，并获取结果\n",
    "        results = ask_database(conn, query)\n",
    "    else:\n",
    "        # 如果功能调用的名称不是 \"ask_database\"，则返回错误信息\n",
    "        results = f\"Error: function {message['function_call']['name']} does not exist\"\n",
    "    return results  # 返回结果"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "65585e74",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[31msystem: Answer user questions by generating SQL queries against the Chinook Music Database.\n",
      "\u001b[0m\n",
      "\u001b[32muser: Hi, who are the top 5 artists by number of tracks?\n",
      "\u001b[0m\n",
      "\u001b[34massistant[function_call]: {'name': 'ask_database', 'arguments': '{\\n  \"query\": \"SELECT artists.Name, COUNT(tracks.TrackId) as TrackCount FROM artists JOIN albums ON artists.ArtistId = albums.ArtistId JOIN tracks ON albums.AlbumId = tracks.AlbumId GROUP BY artists.ArtistId ORDER BY TrackCount DESC LIMIT 5;\"\\n}'}\n",
      "\u001b[0m\n",
      "\u001b[35mfunction (ask_database): [('Iron Maiden', 213), ('U2', 135), ('Led Zeppelin', 114), ('Metallica', 112), ('Deep Purple', 92)]\n",
      "\u001b[0m\n"
     ]
    }
   ],
   "source": [
    "# 创建一个空的消息列表\n",
    "messages = []\n",
    "\n",
    "# 向消息列表中添加一个系统角色的消息，内容是 \"Answer user questions by generating SQL queries against the Chinook Music Database.\"\n",
    "messages.append({\"role\": \"system\", \"content\": \"Answer user questions by generating SQL queries against the Chinook Music Database.\"})\n",
    "\n",
    "# 向消息列表中添加一个用户角色的消息，内容是 \"Hi, who are the top 5 artists by number of tracks?\"\n",
    "messages.append({\"role\": \"user\", \"content\": \"Hi, who are the top 5 artists by number of tracks?\"})\n",
    "\n",
    "# 使用 chat_completion_request 函数获取聊天响应\n",
    "chat_response = chat_completion_request(messages, functions)\n",
    "\n",
    "# 从聊天响应中获取助手的消息\n",
    "assistant_message = chat_response.json()[\"choices\"][0][\"message\"]\n",
    "\n",
    "# 将助手的消息添加到消息列表中\n",
    "messages.append(assistant_message)\n",
    "\n",
    "# 如果助手的消息中有功能调用\n",
    "if assistant_message.get(\"function_call\"):\n",
    "    # 使用 execute_function_call 函数执行功能调用，并获取结果\n",
    "    results = execute_function_call(assistant_message)\n",
    "    # 将功能的结果作为一个功能角色的消息添加到消息列表中\n",
    "    messages.append({\"role\": \"function\", \"name\": assistant_message[\"function_call\"][\"name\"], \"content\": results})\n",
    "\n",
    "# 使用 pretty_print_conversation 函数打印对话\n",
    "pretty_print_conversation(messages)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "8c986892-6aea-47c9-9114-80b4c3ae1156",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[{'table_name': 'albums', 'column_names': ['AlbumId', 'Title', 'ArtistId']},\n",
       " {'table_name': 'sqlite_sequence', 'column_names': ['name', 'seq']},\n",
       " {'table_name': 'artists', 'column_names': ['ArtistId', 'Name']},\n",
       " {'table_name': 'customers',\n",
       "  'column_names': ['CustomerId',\n",
       "   'FirstName',\n",
       "   'LastName',\n",
       "   'Company',\n",
       "   'Address',\n",
       "   'City',\n",
       "   'State',\n",
       "   'Country',\n",
       "   'PostalCode',\n",
       "   'Phone',\n",
       "   'Fax',\n",
       "   'Email',\n",
       "   'SupportRepId']},\n",
       " {'table_name': 'employees',\n",
       "  'column_names': ['EmployeeId',\n",
       "   'LastName',\n",
       "   'FirstName',\n",
       "   'Title',\n",
       "   'ReportsTo',\n",
       "   'BirthDate',\n",
       "   'HireDate',\n",
       "   'Address',\n",
       "   'City',\n",
       "   'State',\n",
       "   'Country',\n",
       "   'PostalCode',\n",
       "   'Phone',\n",
       "   'Fax',\n",
       "   'Email']},\n",
       " {'table_name': 'genres', 'column_names': ['GenreId', 'Name']},\n",
       " {'table_name': 'invoices',\n",
       "  'column_names': ['InvoiceId',\n",
       "   'CustomerId',\n",
       "   'InvoiceDate',\n",
       "   'BillingAddress',\n",
       "   'BillingCity',\n",
       "   'BillingState',\n",
       "   'BillingCountry',\n",
       "   'BillingPostalCode',\n",
       "   'Total']},\n",
       " {'table_name': 'invoice_items',\n",
       "  'column_names': ['InvoiceLineId',\n",
       "   'InvoiceId',\n",
       "   'TrackId',\n",
       "   'UnitPrice',\n",
       "   'Quantity']},\n",
       " {'table_name': 'media_types', 'column_names': ['MediaTypeId', 'Name']},\n",
       " {'table_name': 'playlists', 'column_names': ['PlaylistId', 'Name']},\n",
       " {'table_name': 'playlist_track', 'column_names': ['PlaylistId', 'TrackId']},\n",
       " {'table_name': 'tracks',\n",
       "  'column_names': ['TrackId',\n",
       "   'Name',\n",
       "   'AlbumId',\n",
       "   'MediaTypeId',\n",
       "   'GenreId',\n",
       "   'Composer',\n",
       "   'Milliseconds',\n",
       "   'Bytes',\n",
       "   'UnitPrice']},\n",
       " {'table_name': 'sqlite_stat1', 'column_names': ['tbl', 'idx', 'stat']}]"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "database_schema_dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "b7f47450-52e9-40e6-9426-e5c7ded20d74",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[31msystem: Answer user questions by generating SQL queries against the Chinook Music Database.\n",
      "\u001b[0m\n",
      "\u001b[32muser: Hi, who are the top 5 artists by number of tracks?\n",
      "\u001b[0m\n",
      "\u001b[34massistant[function_call]: {'name': 'ask_database', 'arguments': '{\\n  \"query\": \"SELECT artists.Name, COUNT(tracks.TrackId) as TrackCount FROM artists JOIN albums ON artists.ArtistId = albums.ArtistId JOIN tracks ON albums.AlbumId = tracks.AlbumId GROUP BY artists.ArtistId ORDER BY TrackCount DESC LIMIT 5;\"\\n}'}\n",
      "\u001b[0m\n",
      "\u001b[35mfunction (ask_database): [('Iron Maiden', 213), ('U2', 135), ('Led Zeppelin', 114), ('Metallica', 112), ('Deep Purple', 92)]\n",
      "\u001b[0m\n",
      "\u001b[32muser: What is the name of the album with the most tracks?\n",
      "\u001b[0m\n",
      "\u001b[34massistant[function_call]: {'name': 'ask_database', 'arguments': '{\\n  \"query\": \"SELECT albums.Title, COUNT(tracks.TrackId) as TrackCount FROM albums JOIN tracks ON albums.AlbumId = tracks.AlbumId GROUP BY albums.AlbumId ORDER BY TrackCount DESC LIMIT 1;\"\\n}'}\n",
      "\u001b[0m\n",
      "\u001b[35mfunction (ask_database): [('Greatest Hits', 57)]\n",
      "\u001b[0m\n"
     ]
    }
   ],
   "source": [
    "# 向消息列表中添加一个用户的问题，内容是 \"What is the name of the album with the most tracks?\"\n",
    "messages.append({\"role\": \"user\", \"content\": \"What is the name of the album with the most tracks?\"})\n",
    "\n",
    "# 使用 chat_completion_request 函数获取聊天响应\n",
    "chat_response = chat_completion_request(messages, functions)\n",
    "\n",
    "# 从聊天响应中获取助手的消息\n",
    "assistant_message = chat_response.json()[\"choices\"][0][\"message\"]\n",
    "\n",
    "# 将助手的消息添加到消息列表中\n",
    "messages.append(assistant_message)\n",
    "\n",
    "# 如果助手的消息中有功能调用\n",
    "if assistant_message.get(\"function_call\"):\n",
    "    # 使用 execute_function_call 函数执行功能调用，并获取结果\n",
    "    results = execute_function_call(assistant_message)\n",
    "    # 将功能的结果作为一个功能角色的消息添加到消息列表中\n",
    "    messages.append({\"role\": \"function\", \"content\": results, \"name\": assistant_message[\"function_call\"][\"name\"]})\n",
    "\n",
    "# 使用 pretty_print_conversation 函数打印对话\n",
    "pretty_print_conversation(messages)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6e9f31f8-4404-434d-917b-16bdb7a08ea3",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "b5c5dc6e-0a61-46a1-b04f-03f23c9161b2",
   "metadata": {},
   "source": [
    "## Homework 1\n",
    "\n",
    "**使用第三方天气查询API，实现完整可执行的天气查询应用**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "69af411e-26ce-4c04-845e-180744e7e64f",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
